
class Grover:
    """Append the two stages of a Grover iteration to a caller-supplied circuit.

    The search problem is a system of equations over single-bit variables.
    Equation ``i`` is the XOR of the terms in ``eq_set[i]`` and is expected to
    equal ``eq_value[i]``.  A term is either

    * an ``int`` -- the index of one variable qubit, or
    * a ``tuple`` -- two variable-qubit indices whose AND forms the term
      (only elements ``[0]`` and ``[1]`` are read).

    Qubit layout used by the methods (indices into ``qreg_q``):

    * ``0 .. num_variable - 1``: the variables,
    * ``num_variable .. num_variable + len(eq_set) - 1``: one ancilla per
      equation,
    * ``num_variable + len(eq_set)``: the phase-kickback qubit.

    The circuit object only needs ``x``, ``h``, ``cx``, ``ccx``, ``barrier``
    and ``mcx`` (or the older ``mct``) methods; nothing is imported here.
    """

    def __init__(self, eq_set, eq_value, num_variable):
        """Store the equation system.

        Args:
            eq_set: Sequence of equations, each a sequence of terms (see the
                class docstring).
            eq_value: Sequence with the expected right-hand side of each
                equation; it is indexed in step with ``eq_set``.
            num_variable: Number of variable qubits at the start of the
                register.
        """
        self.eq_set = eq_set
        self.eq_value = eq_value
        self.num_variable = num_variable

    @staticmethod
    def _multi_controlled_x(circuit, controls, target):
        """Append a multi-controlled X, using ``mcx`` when the circuit has it.

        ``mct`` was the older spelling of the same gate and is missing from
        recent Qiskit releases, so it is used only as a fallback.
        """
        apply_gate = getattr(circuit, "mcx", None)
        if apply_gate is None:
            apply_gate = circuit.mct
        apply_gate(controls, target)

    def _equation_plan(self):
        """Translate the equations into ``(gate_name, qubit_indices)`` pairs.

        No circuit is touched here, so a malformed equation is reported
        before any gate has been appended.

        Raises:
            TypeError: If a term is neither an ``int`` nor a ``tuple``.
        """
        plan = []
        for eq_index, terms in enumerate(self.eq_set):
            ancilla = self.num_variable + eq_index
            for term in terms:
                # bool is an int subclass; it is not a meaningful qubit index.
                if isinstance(term, int) and not isinstance(term, bool):
                    plan.append(("cx", (term, ancilla)))
                elif isinstance(term, tuple):
                    plan.append(("ccx", (term[0], term[1], ancilla)))
                else:
                    raise TypeError(
                        "equation {}: term {!r} must be an int or a tuple of "
                        "two ints".format(eq_index, term)
                    )
            # With an expected value of 0 the ancilla is flipped, so that it
            # reads 1 exactly when the XOR of the terms is 0.
            if self.eq_value[eq_index] == 0:
                plan.append(("x", (ancilla,)))
        return plan

    @staticmethod
    def _apply_plan(circuit, qreg_q, plan):
        """Append every gate of ``plan`` to ``circuit`` in order."""
        for gate_name, qubit_indices in plan:
            gate = getattr(circuit, gate_name)
            gate(*(qreg_q[index] for index in qubit_indices))

    def black_box(self, circuit, qreg_q):
        """Append the oracle (black box) for the equation system.

        Steps: put the kickback qubit through X then H, evaluate every
        equation into its ancilla, apply a multi-controlled X from all
        ancillas onto the kickback qubit, undo the H/X on the kickback qubit,
        and finally replay the equation gates to restore the ancillas.

        Args:
            circuit: Circuit object the gates are appended to.
            qreg_q: Indexable and sliceable register laid out as described in
                the class docstring.

        Raises:
            TypeError: If an equation contains an unsupported term; the
                circuit is left untouched in that case.
        """
        plan = self._equation_plan()

        ancilla_end = self.num_variable + len(self.eq_set)
        kickback = qreg_q[ancilla_end]

        circuit.x(kickback)
        circuit.h(kickback)

        self._apply_plan(circuit, qreg_q, plan)

        self._multi_controlled_x(
            circuit, qreg_q[self.num_variable:ancilla_end], kickback
        )

        circuit.h(kickback)
        circuit.x(kickback)

        # Restore the ancillas by replaying the same gates in the same order.
        # NOTE(review): this is a valid inverse as long as every term index
        # refers to a variable qubit (never an ancilla), because then all
        # gates in the plan are self-inverse and commute with each other.
        self._apply_plan(circuit, qreg_q, plan)

    def inversion_about_average(self, circuit, qreg_q):
        """Append the inversion-about-the-average (diffusion) stage.

        Acts on the variable qubits only: H and X on all of them, then a
        multi-controlled X onto the last variable qubit sandwiched between
        two H gates, then X and H on all of them again.  Barriers are placed
        where the original layout had them.

        Args:
            circuit: Circuit object the gates are appended to.
            qreg_q: Register whose first ``num_variable`` entries are the
                variable qubits.
        """
        variables = qreg_q[:self.num_variable]
        last_variable = qreg_q[self.num_variable - 1]

        circuit.h(variables)
        circuit.barrier()
        circuit.x(variables)
        circuit.h(last_variable)

        self._multi_controlled_x(
            circuit, qreg_q[:self.num_variable - 1], last_variable
        )

        circuit.h(last_variable)
        circuit.x(variables)
        circuit.barrier()
        circuit.h(variables)

# Disabled draft of a `run` method. It is kept as a bare module-level string
# literal, so it is never executed and is NOT part of the Grover class above.
# As written it would create the registers and circuit, put the variable
# qubits through H, repeat black_box + inversion_about_average
# int(sqrt(2**num_variable) / 4 * pi) times, and measure the variable qubits.
# NOTE(review): QuantumRegister, ClassicalRegister, QuantumCircuit and pi are
# not imported in the visible part of this file -- they would have to be
# imported, and the method moved into the class body, before re-enabling it.
'''
    def run(self):
        qreg_q = QuantumRegister(self.num_variable + len(self.eq_set) + 1, 'q')
        creg_c = ClassicalRegister(self.num_variable, 'c')
        circuit = QuantumCircuit(qreg_q, creg_c)

        circuit.h(qreg_q[:self.num_variable])
        z = int(((2**self.num_variable)**0.5)/4*pi)
        for i in range(z):
            self.black_box(circuit, qreg_q)
            self.inversion_about_average(circuit, qreg_q)
        circuit.measure(qreg_q[:self.num_variable], creg_c[:self.num_variable])

        return circuit
'''